#!/usr/bin/env bash
# Hyperparameter sweep over lambda_coeff for naive weak-to-strong (WTS)
# fine-tuning of Pythia models via wts.py, on the dataset named by glue_ds.
#
# -u: abort on references to unset variables (catches typos in names).
# -e is deliberately NOT set: a single failed wts.py run should not abort the
# remaining runs of the sweep.
set -u

weak_model_size="14m"
strong_model_size="410m"
num_epochs=1
glue_ds="advgluepp"

weak_model="EleutherAI/pythia-${weak_model_size}"
strong_model="EleutherAI/pythia-${strong_model_size}"
# Replace every '.' with 'p' (e.g. "1.4b" -> "1p4b") so sizes are safe to embed
# in file names; pure parameter expansion, no echo|sed subprocess.
weak_model_size_str=${weak_model_size//./p}
strong_model_size_str=${strong_model_size//./p}

tasks=("sst2" "qqp" "mnli" "mnli-mm" "qnli" "rte")

# Sweep: for every run tag and every lambda_coeff value, and for every task,
# fine-tune the weak model and then run naive WTS fine-tuning of the strong
# model. Results for each (lambda, run) pair go to their own JSON file.
for run in 1 6 11; do
    for lambda_coeff in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0; do
        # "0.3" -> "0p3" so the value is safe in a directory name.
        lambda_coeff_str=${lambda_coeff//./p}
        results_file="results/${glue_ds}/hyperparam_wts/lambda-${lambda_coeff_str}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"
        for task in "${tasks[@]}"; do
            # Flags shared by every wts.py invocation for this (run, lambda, task).
            common_args=(
                --results_file "$results_file"
                --task "$task"
                --num_epochs "$num_epochs"
                --validate_on adversarial
                --tag "-${run}"
                --glue_ds "$glue_ds"
            )

            # Weak fine-tuning.
            # NOTE(review): lambda_coeff is pinned to 0.3 here instead of the
            # swept value, so this step does not vary with the lambda loop and
            # is repeated 11 times per (run, task) — confirm that is intended.
            if ! python wts.py \
                --model_id "$weak_model" \
                --ft_mode weak \
                --lambda_coeff 0.3 \
                "${common_args[@]}"; then
                # Without a fresh weak run, the WTS step below would read a
                # missing or stale weak-labels file, so skip it.
                printf 'weak fine-tuning failed (run=%s lambda=%s task=%s); skipping WTS step\n' \
                    "$run" "$lambda_coeff" "$task" >&2
                continue
            fi

            # Naive WTS fine-tuning (Phase 3) of the strong model.
            # Reads weak labels from this path; presumably written by the weak
            # step above — TODO confirm against wts.py.
            python wts.py \
                --model_id "$strong_model" \
                --ft_mode wts-naive \
                --weak_labels_file "weak_labels/${glue_ds}/${weak_model}-${run}.json" \
                --lambda_coeff "$lambda_coeff" \
                "${common_args[@]}"

            # Strong fine-tuning (disabled)
            # python wts.py \
            #     --model_id "$strong_model" \
            #     --ft_mode strong \
            #     --lambda_coeff "$lambda_coeff" \
            #     "${common_args[@]}"
        done
    done
done
